Imports¶

In [1]:
import os
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

if not 'have_changed_cwd' in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
    os.chdir(PROJECT_ROOT)
    print(f"Working directory: {os.getcwd()}")
    have_changed_cwd = True


from src.models.classification.classifier import ImageClassifier
from src.dataset.cifar import CIFAR10DataModule
from src.models.autoencoding.autoencoder import Autoencoder
from src.utils import plot_image_grid

chkpt_path = "classifiers/resnet9-cifar10/2024-12-27/19-57-36/checkpoints/last.ckpt"
# chkpt_path = "classifiers/linear-cifar10/2024-12-27/20-41-01/checkpoints/last.ckpt"
classifier = ImageClassifier.load_from_checkpoint(chkpt_path)

chkpt_path = "autoencoders/convae-cifar10/2024-12-27/20-11-11/checkpoints/last.ckpt"
autoencoder = Autoencoder.load_from_checkpoint(chkpt_path)

datamodule = CIFAR10DataModule.get_default_dataset("cifar10", samples_per_class=100)

x, y = next(iter(datamodule.train_dataloader()))
N = 9
x = x[:N]
y = y[:N]
Working directory: /Users/mat/Desktop/Files/Code/Generative-Data-Augmentation
Files already downloaded and verified
Files already downloaded and verified

Classifier Predictions¶

In [58]:
x, y = next(iter(datamodule.val_dataloader()))
N = 16
x = x[:N]
y = y[:N]
In [15]:
from src.eval.classification import inference_with_classifier

out = inference_with_classifier(classifier, datamodule.test_dataloader(), device='mps')
Doing inference on 16 batches on mps: 100%|██████████| 16/16 [00:00<00:00, 41.14it/s]
In [16]:
from sklearn.metrics import classification_report

print(classification_report(out['target'], out['pred'], target_names=datamodule.class_names))
              precision    recall  f1-score   support

    airplane       0.89      0.74      0.81       100
  automobile       0.94      0.89      0.91       100
        bird       0.78      0.90      0.83       100
         cat       0.78      0.81      0.79       100
        deer       0.87      0.87      0.87       100
         dog       0.82      0.81      0.81       100
        frog       0.95      0.80      0.87       100
       horse       0.85      0.90      0.87       100
        ship       0.91      0.86      0.89       100
       truck       0.80      0.95      0.87       100

    accuracy                           0.85      1000
   macro avg       0.86      0.85      0.85      1000
weighted avg       0.86      0.85      0.85      1000

In [8]:
from src.eval.classification import collect_misclassified, plot_image_grid
from matplotlib import pyplot as plt

misclassified_images, true_labels, pred_labels, pred_logits = collect_misclassified(classifier, datamodule.val_dataloader(), device='mps', num_samples=9)
In [9]:
plot_image_grid(misclassified_images, true_labels, pred_logits, class_names=datamodule.class_names)
plt.show()

Autoencoder predictions¶

In [ ]:
x, y = next(iter(datamodule.val_dataloader()))
N = 16
x = x[:N]
y = y[:N]
In [10]:
from src.eval.autoencoding import inference_with_autoencoder, collect_high_error_reconstructions, plot_reconstruction_pairs

out = inference_with_autoencoder(autoencoder, datamodule.test_dataloader(), device='mps')
In [11]:
plt.hist(out['mse_losses'], bins=100)
plt.title(f"Reconstruction Error Histogram (MSE) - mean: {out['mse_losses'].mean():.4f}")
plt.grid()
plt.show()
In [12]:
original_inputs, reconstructions, latent_codes, mse_values = collect_high_error_reconstructions(autoencoder, datamodule.test_dataloader(), device='mps', threshold=0.015, num_samples=9)
In [13]:
plot_reconstruction_pairs(original_inputs, reconstructions, num_pairs=5, figsize=(8, 8), mse_losses=mse_values)
plt.show()

Gradients¶

In [ ]:
x, y = next(iter(datamodule.val_dataloader()))
N = 16
x = x[:N]
y = y[:N]

Pixel Space¶

In [ ]:
num_steps = 300
lr = 0.001
weight_decay = 0.01
optimizer_cls = torch.optim.AdamW
In [ ]:
from src.input_gradients import compute_proba_grads_wrt_data, plot_grads_wrt_data

grads, _ = compute_proba_grads_wrt_data(classifier, x.clone(), list(range(10)), device='mps')
plot_grads_wrt_data(x, grads, list(range(10)), datamodule.class_names)
plt.show()
No description has been provided for this image
In [22]:
from src.input_gradients import optimize_proba_wrt_data, plot_optimal_images

config = {
    "num_steps": num_steps,
    "optimizer_cls": optimizer_cls,
    "save_k_intermediate_imgs": 10,
    "logit_transform": None,
    "optimizer_kwargs": {
        "lr": lr,
        "weight_decay": weight_decay,
    },
}

out = optimize_proba_wrt_data(
    classifier,
    x.clone(),
    list(range(10)),
    config,
    device='mps',
)
plot_optimal_images(out, datamodule.class_names)
plt.show()
No description has been provided for this image
In [23]:
from src.input_gradients import plot_optimization_trajectory_fixed_sample, plot_optimization_trajectory_fixed_target

plot_optimization_trajectory_fixed_target(out, 2, datamodule.class_names)
plt.show()
No description has been provided for this image
No description has been provided for this image
In [24]:
plot_optimization_trajectory_fixed_sample(out, 0, datamodule.class_names, targets=[0, 1, 2], figsizes=((15, 8), (15, 12)))
plt.show()
No description has been provided for this image
No description has been provided for this image

Latent Space¶

In [38]:
num_steps = 500
lr = 0.001
weight_decay = 0.03
optimizer_cls = torch.optim.AdamW
In [ ]:
from src.input_gradients import compute_proba_grads_wrt_data, plot_grads_wrt_data

grads, finite_diffs = compute_proba_grads_wrt_data(classifier, x.clone(), list(range(10)), device='mps', autoencoder=autoencoder, epsilon=0.0001)
plot_grads_wrt_data(x, finite_diffs, list(range(10)), datamodule.class_names)
plt.show()
No description has been provided for this image
In [40]:
from src.input_gradients import optimize_proba_wrt_data, plot_optimal_images

config = {
    "num_steps": num_steps,
    "optimizer_cls": optimizer_cls,
    "save_k_intermediate_imgs": 10,
    "logit_transform": None,
    "optimizer_kwargs": {
        "lr": lr,
        "weight_decay": weight_decay,
    },
}

out = optimize_proba_wrt_data(
    classifier,
    x.clone(),
    list(range(10)),
    config,
    device='mps',
    autoencoder=autoencoder,
)
plot_optimal_images(out, datamodule.class_names)
plt.show()
No description has been provided for this image
In [41]:
from src.input_gradients import plot_optimization_trajectory_fixed_sample, plot_optimization_trajectory_fixed_target

plot_optimization_trajectory_fixed_target(out, 2, datamodule.class_names)
plt.show()
No description has been provided for this image
No description has been provided for this image
In [42]:
plot_optimization_trajectory_fixed_sample(out, 0, datamodule.class_names, targets=[0, 1, 2], figsizes=((15, 8), (15, 12)))
plt.show()
No description has been provided for this image
No description has been provided for this image
In [ ]:
torch.save(out, "optimal_images.pth")
In [46]:
out_loaded = torch.load("optimal_images.pth")
/var/folders/jy/x5558th97mjgtzp6f33ryf840000gn/T/ipykernel_59460/595520248.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  out_loaded = torch.load("optimal_images.pth")

Flatness¶

In [2]:
x, y = next(iter(datamodule.val_dataloader()))
N = 3
x = x[:N]
y = y[:N]
In [12]:
from src.eval.flatness import compute_local_energy, plot_local_energy

local_energy = compute_local_energy(
    classifier,
    datamodule.test_dataloader(),
    1000,
    None,
    np.linspace(0.0, 0.5, 6),
    10,
    'mps'
)
Computing Local Energy for various noise levels: 100%|██████████| 6/6 [00:17<00:00,  2.84s/it]
In [ ]:
fig = plot_local_energy(local_energy)
plt.show()
No description has been provided for this image
In [3]:
from src.eval.flatness import compute_input_flatness, plot_inputs_flatness


results, gt = compute_input_flatness(
    classifier,
    [(x, y)],
    2,
    np.linspace(0.0, 1.0, 11),
    100,
    'mps'
)
Processing batches:   0%|          | 0/1 [00:03<?, ?it/s]
In [4]:
plot_inputs_flatness(results, gt, -1, class_names=datamodule.class_names, filter_misclassified=True, print_summary=False, plot_global_average=True, plot_individual_samples=True)
plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [5]:
from src.eval.flatness import visualize_input_noise

visualize_input_noise(x)
plt.show()
No description has been provided for this image
In [6]:
from src.eval.input_gradients import optimize_proba_wrt_data, plot_optimal_images

config = {
    "num_steps": 300,
    "optimizer_cls": torch.optim.AdamW,
    "save_k_intermediate_imgs": 10,
    "logit_transform": None,
    "perturb_weights": True,
    "stddev": 0.2,
    "weights_sample_freq": 10,
    "optimizer_kwargs": {
        "lr": 0.001,
        "weight_decay": 0.01,
    },
}

out = optimize_proba_wrt_data(
    classifier,
    x.clone(),
    list(range(10)),
    config,
    device='mps',
)
plot_optimal_images(out, datamodule.class_names)
plt.show()
Optimizing target 0: 100%|██████████| 300/300 [00:02<00:00, 115.66it/s]
Optimizing target 1: 100%|██████████| 300/300 [00:02<00:00, 142.51it/s]
Optimizing target 2: 100%|██████████| 300/300 [00:02<00:00, 138.18it/s]
Optimizing target 3: 100%|██████████| 300/300 [00:02<00:00, 142.88it/s]
Optimizing target 4: 100%|██████████| 300/300 [00:02<00:00, 139.71it/s]
Optimizing target 5: 100%|██████████| 300/300 [00:02<00:00, 142.34it/s]
Optimizing target 6: 100%|██████████| 300/300 [00:02<00:00, 128.53it/s]
Optimizing target 7: 100%|██████████| 300/300 [00:02<00:00, 131.86it/s]
Optimizing target 8: 100%|██████████| 300/300 [00:02<00:00, 140.45it/s]
Optimizing target 9: 100%|██████████| 300/300 [00:02<00:00, 138.63it/s]
No description has been provided for this image

Genuine vs Adversarial examples¶

In [2]:
x, y = next(iter(datamodule.val_dataloader()))
N = 16
x = x[:N]
y = y[:N]
In [5]:
meanss = []
stdss = []
labels = []
In [3]:
plot_image_grid(x, y, datamodule.class_names, figsize=(4, 4))
plt.show()
No description has been provided for this image
In [4]:
from src.eval.flatness import compute_input_flatness, plot_inputs_flatness


results, gt = compute_input_flatness(
    classifier,
    datamodule.test_dataloader(),
    1000,
    np.linspace(0.0, 1.0, 11),
    9,
    'mps'
)
Processing batches:  94%|█████████▍| 15/16 [00:22<00:01,  1.51s/it]
In [6]:
fig1, fig2, fig3, means, stds = plot_inputs_flatness(results, gt, -1, class_names=datamodule.class_names, filter_misclassified=True, plot_global_average=True)
plt.show()
meanss.append(means)
stdss.append(stds)
labels.append("Genuine Samples")
No description has been provided for this image
No description has been provided for this image

Adversarial pixel space¶

In [7]:
from src.eval.input_gradients import optimize_proba_wrt_data, plot_optimal_images

config = {
    "num_steps": 300,
    "optimizer_cls": torch.optim.AdamW,
    "save_k_intermediate_imgs": 10,
    "logit_transform": None,
    "perturb_weights": False,
    "stddev": 0.2,
    "weights_sample_freq": 10,
    "optimizer_kwargs": {
        "lr": 0.001,
        "weight_decay": 0.01,
    },
}

out = optimize_proba_wrt_data(
    classifier,
    x.clone(),
    list(range(10)),
    config,
    device='mps',
)
Optimizing target 0:   0%|          | 0/300 [00:00<?, ?it/s]
Optimizing target 0: 100%|██████████| 300/300 [00:04<00:00, 64.61it/s]
Optimizing target 1: 100%|██████████| 300/300 [00:04<00:00, 67.04it/s]
Optimizing target 2: 100%|██████████| 300/300 [00:04<00:00, 67.97it/s]
Optimizing target 3: 100%|██████████| 300/300 [00:04<00:00, 67.38it/s]
Optimizing target 4: 100%|██████████| 300/300 [00:04<00:00, 67.71it/s]
Optimizing target 5: 100%|██████████| 300/300 [00:04<00:00, 67.63it/s]
Optimizing target 6: 100%|██████████| 300/300 [00:04<00:00, 67.06it/s]
Optimizing target 7: 100%|██████████| 300/300 [00:04<00:00, 64.29it/s]
Optimizing target 8: 100%|██████████| 300/300 [00:04<00:00, 67.96it/s]
Optimizing target 9: 100%|██████████| 300/300 [00:04<00:00, 67.17it/s]
In [8]:
adv_images = torch.cat([out["final_imgs"][i] for i in range(10)], dim=0)
adv_labels = torch.cat([torch.tensor([i for _ in range(N)]) for i in range(10)], dim=0)
adv_images.shape, adv_labels.shape
Out[8]:
(torch.Size([160, 3, 32, 32]), torch.Size([160]))
In [9]:
adv_results, adv_gt = compute_input_flatness(
    classifier,
    [(adv_images, adv_labels)],
    None,
    np.linspace(0.0, 1.0, 11),
    9,
    'mps'
)
Processing batches: 100%|██████████| 1/1 [00:03<00:00,  3.37s/it]
In [10]:
fig1, fig2, fig3, means, stds = plot_inputs_flatness(adv_results, adv_gt, -1, class_names=datamodule.class_names, filter_misclassified=True, plot_global_average=True)
plt.show()
meanss.append(means)
stdss.append(stds)
labels.append("Adversarial Pixels")
No description has been provided for this image
No description has been provided for this image

Adversarial latent space¶

In [11]:
from src.eval.input_gradients import optimize_proba_wrt_data, plot_optimal_images

config = {
    "num_steps": 600,
    "optimizer_cls": torch.optim.AdamW,
    "save_k_intermediate_imgs": 10,
    "logit_transform": None,
    "perturb_weights": False,
    "stddev": 0.2,
    "weights_sample_freq": 10,
    "optimizer_kwargs": {
        "lr": 0.001,
        "weight_decay": 0.01,
    },
}

out = optimize_proba_wrt_data(
    classifier,
    x.clone(),
    list(range(10)),
    config,
    device='mps',
    autoencoder=autoencoder,
)
Optimizing target 0: 100%|██████████| 600/600 [00:46<00:00, 13.04it/s]
Optimizing target 1: 100%|██████████| 600/600 [00:46<00:00, 13.03it/s]
Optimizing target 2: 100%|██████████| 600/600 [00:45<00:00, 13.17it/s]
Optimizing target 3: 100%|██████████| 600/600 [00:45<00:00, 13.27it/s]
Optimizing target 4: 100%|██████████| 600/600 [00:45<00:00, 13.27it/s]
Optimizing target 5: 100%|██████████| 600/600 [00:45<00:00, 13.27it/s]
Optimizing target 6: 100%|██████████| 600/600 [00:46<00:00, 13.02it/s]
Optimizing target 7: 100%|██████████| 600/600 [00:45<00:00, 13.16it/s]
Optimizing target 8: 100%|██████████| 600/600 [00:45<00:00, 13.11it/s]
Optimizing target 9: 100%|██████████| 600/600 [00:45<00:00, 13.05it/s]
In [12]:
adv_images = torch.cat([out["final_imgs"][i] for i in range(10)], dim=0)
adv_labels = torch.cat([torch.tensor([i for _ in range(N)]) for i in range(10)], dim=0)
adv_images.shape, adv_labels.shape
Out[12]:
(torch.Size([160, 3, 32, 32]), torch.Size([160]))
In [13]:
adv_results, adv_gt = compute_input_flatness(
    classifier,
    [(adv_images, adv_labels)],
    None,
    np.linspace(0.0, 1.0, 11),
    9,
    'mps'
)
Processing batches: 100%|██████████| 1/1 [00:03<00:00,  3.57s/it]
In [14]:
fig1, fig2, fig3, means, stds = plot_inputs_flatness(adv_results, adv_gt, -1, class_names=datamodule.class_names, filter_misclassified=True, plot_global_average=True)
plt.show()
meanss.append(means)
stdss.append(stds)
labels.append("Adversarial Latent")
No description has been provided for this image
No description has been provided for this image

With classifier noise¶

In [15]:
from src.eval.input_gradients import optimize_proba_wrt_data, plot_optimal_images

config = {
    "num_steps": 300,
    "optimizer_cls": torch.optim.AdamW,
    "save_k_intermediate_imgs": 10,
    "logit_transform": None,
    "perturb_weights": True,
    "stddev": 0.2,
    "weights_sample_freq": 10,
    "optimizer_kwargs": {
        "lr": 0.001,
        "weight_decay": 0.01,
    },
}

out = optimize_proba_wrt_data(
    classifier,
    x.clone(),
    list(range(10)),
    config,
    device='mps',
)
Optimizing target 0: 100%|██████████| 300/300 [00:04<00:00, 61.80it/s]
Optimizing target 1: 100%|██████████| 300/300 [00:04<00:00, 63.31it/s]
Optimizing target 2: 100%|██████████| 300/300 [00:04<00:00, 65.58it/s]
Optimizing target 3: 100%|██████████| 300/300 [00:04<00:00, 65.69it/s]
Optimizing target 4: 100%|██████████| 300/300 [00:04<00:00, 65.79it/s]
Optimizing target 5: 100%|██████████| 300/300 [00:04<00:00, 65.49it/s]
Optimizing target 6: 100%|██████████| 300/300 [00:04<00:00, 65.63it/s]
Optimizing target 7: 100%|██████████| 300/300 [00:04<00:00, 65.77it/s]
Optimizing target 8: 100%|██████████| 300/300 [00:04<00:00, 63.50it/s]
Optimizing target 9: 100%|██████████| 300/300 [00:04<00:00, 65.63it/s]
In [16]:
adv_images = torch.cat([out["final_imgs"][i] for i in range(10)], dim=0)
adv_labels = torch.cat([torch.tensor([i for _ in range(N)]) for i in range(10)], dim=0)
adv_images.shape, adv_labels.shape
Out[16]:
(torch.Size([160, 3, 32, 32]), torch.Size([160]))
In [17]:
adv_results, adv_gt = compute_input_flatness(
    classifier,
    [(adv_images, adv_labels)],
    None,
    np.linspace(0.0, 1.0, 11),
    9,
    'mps'
)
Processing batches: 100%|██████████| 1/1 [00:03<00:00,  3.43s/it]
In [18]:
fig1, fig2, fig3, means, stds = plot_inputs_flatness(adv_results, adv_gt, -1, class_names=datamodule.class_names, filter_misclassified=True, plot_global_average=True)
plt.show()
meanss.append(means)
stdss.append(stds)
labels.append("Adversarial Pixel noisy weights")
No description has been provided for this image
No description has been provided for this image
In [25]:
fig, ax = plt.subplots(figsize=(8, 6))
colors = ["red", "blue", "yellow", "black"]
for means, stds, label, color in zip(meanss, stdss, labels, colors):
    ax.errorbar(np.linspace(0.0, 1.0, 11), means, yerr=stds, label=label, color=color)
plt.xticks(np.linspace(0.0, 1.0, 11))
plt.yticks(np.linspace(0.0, 1.0, 11))
plt.xlabel("Noise Level")
plt.ylabel("probability")
plt.title("Flatness of the input landscape")
plt.grid()
plt.legend()
plt.show()
No description has been provided for this image